# -*- coding: utf-8 -*-
"""
Created on Fri Jan 13 20:19:26 2023

@author:ikoreed
"""

import numpy as np

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Directories: the two sources that get interleaved, and the output location.
file_dir_A = '/A/'
file_dir_B = '/B/'
save_dir = '/'

# Filename stems for the files read (batch_str) and written (save_str).
batch_str = '64x64_50um'
save_str = '64x64_50um'

# File indexing and sizes.
num_files = 154
start_ind = 0
file_size = 1024          # number of rows written per output standoff file
batch_size = 8            # rows per interleaved batch

# Derived sizes.
half_batch = batch_size // 2                 # not referenced by the loop below
steps = file_size // (2 * batch_size)        # A/B batch pairs per output file

# When True, each sample is divided by its peak vector magnitude before saving.
image_normalization = False

def _load_set(directory, stem, z_prefix, index):
    """Load the ``_Bxyz_``, ``_Jxy_`` and standoff-distance arrays for one index.

    Paths are ``directory + stem + '_Bxyz_' + NNN + '.npy'``,
    ``directory + stem + '_Jxy_' + NNN + '.npy'`` and
    ``directory + z_prefix + NNN + '.npy'``, where NNN is ``index`` zero-padded
    to 3 digits.

    Returns:
        tuple ``(arr_in, arr_out, arr_z)`` as returned by ``np.load``.
    """
    tag = str(index).zfill(3)
    arr_in = np.load(directory + stem + '_Bxyz_' + tag + '.npy')
    arr_out = np.load(directory + stem + '_Jxy_' + tag + '.npy')
    arr_z = np.load(directory + z_prefix + tag + '.npy')
    return arr_in, arr_out, arr_z


def _interleave(src_a, src_b, a_offset, n_rows, n_steps, bsize):
    """Build one output file by alternating batches taken from A and B.

    Output batch ``2*i`` is A rows ``[a_offset + i*bsize, a_offset + (i+1)*bsize)``.
    Output batch ``2*i+1`` is B rows ``[i*bsize, (i+1)*bsize)``.
    Only the first 3 entries of the last axis of the input/output arrays are
    copied. Rows beyond ``2*n_steps*bsize`` stay zero.

    Args:
        src_a, src_b: ``(arr_in, arr_out, arr_z)`` tuples from ``_load_set``.
        a_offset: first row of A to use (lets successive outputs walk through A).
        n_rows: length of the returned standoff-distance vector.
        n_steps: number of A/B batch pairs to copy.
        bsize: rows per batch.

    Returns:
        tuple ``(d_in, d_out, d_z)`` of freshly allocated float arrays.
    """
    a_in, a_out, a_z = src_a
    b_in, b_out, b_z = src_b

    # Fresh buffers for every output file so no rows leak from a previous one.
    d_in = np.zeros([a_in.shape[0], a_in.shape[1], a_in.shape[2], 3])
    d_out = np.zeros(np.shape(a_out))
    d_z = np.zeros(n_rows)

    for i in range(n_steps):
        a_rows = slice(a_offset + i * bsize, a_offset + (i + 1) * bsize)
        b_rows = slice(i * bsize, (i + 1) * bsize)
        even = slice(2 * i * bsize, (2 * i + 1) * bsize)
        odd = slice((2 * i + 1) * bsize, (2 * i + 2) * bsize)

        d_in[even] = a_in[a_rows, :, :, :3]
        d_out[even] = a_out[a_rows, :, :, :3]
        d_z[even] = a_z[a_rows]

        d_in[odd] = b_in[b_rows, :, :, :3]
        d_out[odd] = b_out[b_rows, :, :, :3]
        d_z[odd] = b_z[b_rows]

    return d_in, d_out, d_z


def _normalize(d_in, d_out):
    """Scale each sample (row) of ``d_in`` and ``d_out`` in place.

    The divisor for row j is the maximum, over axes 1 and 2, of the Euclidean
    norm of ``d_in[j, ..., 0:3]``. Rows whose divisor is 0 are left unchanged.
    """
    magnitude = (d_in[..., 0] ** 2 + d_in[..., 1] ** 2 + d_in[..., 2] ** 2) ** 0.5
    peak = np.max(magnitude, axis=(1, 2))
    nonzero = peak != 0
    divisor = peak[nonzero]
    d_in[nonzero] /= divisor.reshape((-1,) + (1,) * (d_in.ndim - 1))
    d_out[nonzero] /= divisor.reshape((-1,) + (1,) * (d_out.ndim - 1))


def _save_set(directory, stem, index, d_in, d_out, d_z):
    """Write the three arrays of one output file to ``directory``."""
    tag = str(index).zfill(3)
    np.save(directory + stem + '_Bxyz_' + tag + '.npy', d_in)
    np.save(directory + stem + '_Jxy_' + tag + '.npy', d_out)
    np.save(directory + 'standoff_distances_' + tag + '.npy', d_z)


# NOTE(review): directory B's standoff files are read as 'standoff_distances'
# + NNN (no '_'), unlike A and the outputs ('standoff_distances_' + NNN).
# Kept unchanged; confirm this matches the file names on disk.
for nf in range(start_ind, num_files // 2):
    print(nf)

    set_A = _load_set(file_dir_A, batch_str, 'standoff_distances_', nf)
    set_B = _load_set(file_dir_B, batch_str, 'standoff_distances', nf)

    # First output: A from row 0, interleaved with B file nf.
    out = _interleave(set_A, set_B, 0, file_size, steps, batch_size)
    if image_normalization:
        _normalize(out[0], out[1])
    _save_set(save_dir, save_str, nf, *out)

    # Second output: A continues where the first output stopped
    # (steps*batch_size rows in; previously hard-coded as 512), interleaved
    # with B file nf + num_files.
    set_B = _load_set(file_dir_B, batch_str, 'standoff_distances', nf + num_files)
    out = _interleave(set_A, set_B, steps * batch_size, file_size, steps, batch_size)
    if image_normalization:
        _normalize(out[0], out[1])
    _save_set(save_dir, save_str, nf + num_files, *out)

   